class ActivationsAndGradients:
    """Class for extracting activations and
    registering gradients from targeted intermediate layers"""

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation)
            )
            # Because of https://github.com/pytorch/pytorch/issues/61519,
            # we don't use a backward hook to record gradients. Instead,
            # save_gradient is also a forward hook; it attaches a tensor
            # hook to the layer's output (see save_gradient below).
            self.handles.append(
                target_layer.register_forward_hook(self.save_gradient)
            )

    def save_activation(self, module, input, output):
        activation = output

        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, input, output):
        if not hasattr(output, "requires_grad") or not output.requires_grad:
            # Tensor hooks can only be registered on tensors that require
            # grad, so warn and skip this layer.
            print(f"Warning: Layer {module} does not have gradients.")
            return

        # Backward hooks fire in reverse order of the forward pass, so each
        # gradient is prepended to keep self.gradients aligned with
        # self.activations.
        def _store_grad(grad):
            if self.reshape_transform is not None:
                grad = self.reshape_transform(grad)
            self.gradients = [grad.cpu().detach()] + self.gradients

        output.register_hook(_store_grad)

    def __call__(self, x):
        # Reset stored tensors so repeated calls don't accumulate stale state.
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        for handle in self.handles:
            handle.remove()


class TokenAAG(ActivationsAndGradients):
    """Variant of ActivationsAndGradients that also records the *inputs*
    to the targeted intermediate layers (token-level activations)."""

    def __init__(self, model, target_layers):
        super().__init__(model, target_layers, reshape_transform=None)
        self.activations_in = []

    def __call__(self, x):
        # Also reset the input-side list; the parent resets the other two.
        self.activations_in = []
        return super().__call__(x)

    def save_activation(self, module, input, output):
        activation = output
        activation_in = input[0]

        # self.reshape_transform is always None here (set in __init__), so the
        # token-level path goes through token_reshape_transform instead, which
        # is currently an identity placeholder (see below).
        activation_in = self.token_reshape_transform(activation_in)
        activation = self.token_reshape_transform(activation)

        self.activations.append(activation.cpu().detach())
        self.activations_in.append(activation_in.cpu().detach())

    def token_reshape_transform(self, tensor, height=14, width=14):
        """Placeholder hook for reshaping token sequences; currently an identity.

        Args:
            tensor (torch.Tensor): Input tensor of shape [B, L, D], where B is
                the batch size, L is the number of tokens, and D is the channel
                (embedding) dimension.
            height (int, optional): Target height for a token-to-grid reshape.
                Default is 14.
            width (int, optional): Target width for a token-to-grid reshape.
                Default is 14.
        Returns:
            torch.Tensor: The tensor, unchanged. Override or extend this method
                to map [B, L, D] tokens into a [B, D, height, width] feature map.
        """
        return tensor
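

# Hedged sketch (an assumption, not the original implementation): one way
# token_reshape_transform could be filled in for a ViT-style encoder whose
# layer output is [B, 1 + H*W, D] with a leading class token. It drops the
# class token and rearranges the patch tokens into a [B, D, H, W] feature
# map so downstream code can treat them like CNN activations.
def _vit_token_reshape_sketch(tensor, height=14, width=14):
    # [B, 1 + H*W, D] -> [B, H, W, D], discarding the class token at index 0
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    # [B, H, W, D] -> [B, D, H, W] to match the CNN activation layout
    return result.permute(0, 3, 1, 2)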
